import os
import sys

sys.path.append(os.path.dirname(os.path.realpath(__file__)))
from .layer_utils import *
from .functional import multi_head_attention_forward
from torch.nn.init import xavier_normal_, xavier_uniform_, constant_
from torch.nn.parameter import Parameter
import torch.nn as nn
import torch


# TODO: Replace the MultiHeadAttention layer with the code written for CADLGen (Adding Flash and Memory efficient Attention)
class MultiHeadAttention(nn.Module):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need
    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h) W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
    Args:
        input_dim: Input dimension of the sequence
        embed_dim: total dimension of the model.
        num_heads: parallel attention heads.
        dropout: a Dropout layer on attn_output_weights. Default: 0.0.
        bias: add bias as module parameter. Default: True.
        add_bias_kv: add bias to the key and value sequences at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        kdim: total number of features in key. Default: None.
        vdim: total number of features in value. Default: None.
        Note: if kdim and vdim are None, they will be set to embed_dim such that
        query, key, and value have the same number of features.
    Examples::
        >>> multihead_attn = MultiHeadAttention(input_dim, embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    __annotations__ = {
        "bias_k": torch._jit_internal.Optional[torch.Tensor],
        "bias_v": torch._jit_internal.Optional[torch.Tensor],
    }
    __constants__ = [
        "q_proj_weight",
        "k_proj_weight",
        "v_proj_weight",
        "in_proj_weight",
    ]

    def __init__(
        self,
        input_dim,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
    ):
        super(MultiHeadAttention, self).__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
            self.k_proj_weight = Parameter(torch.empty(embed_dim, self.kdim))
            self.v_proj_weight = Parameter(torch.empty(embed_dim, self.vdim))
            self.register_parameter("in_proj_weight", None)
        else:
            self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
            self.register_parameter("q_proj_weight", None)
            self.register_parameter("k_proj_weight", None)
            self.register_parameter("v_proj_weight", None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
            self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiHeadAttention, self).__setstate__(state)

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        use_3d=False,
    ):
        r"""
        Args:
            query: input Tensor of shape (batch_size,n_samples,n_features) for query generation
            key: input Tensor of shape (batch_size,n_samples,n_features) for key generation
            value: input Tensor of shape (batch_size,n_samples,n_features) for value generation
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. This is an binary mask. When the value is True,
                the corresponding value on the attention layer will be filled with -inf.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. This is an additive mask
                (i.e. the values will be added to the attention layer). A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        Shape:
            - Inputs:
            - query: :math:`(B, N, E)` where N is the target sequence length, B is the batch size, E is
              the embedding dimension.
            - key: :math:`(B, N, E)` where N is the target sequence length, B is the batch size, E is
            the embedding dimension.
            - vaue: :math:`(B, N, E)` where N is the target sequence length, B is the batch size, E is
            the embedding dimension.
            - key_padding_mask: :math:`(B, S)`, ByteTensor, where B is the batch size, S is the source sequence length.
            - attn_mask: 2D mask :math:`(N, S)` where N is the target sequence length, S is the source sequence length.
              3D mask :math:`(B*num_heads, N, S)` where B is the batch size, N is the target sequence length,
              S is the source sequence length.
            - Outputs:
            - attn_output: :math:`(N, B, E)` where N is the target sequence length, B is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(B, N, S)` where B is the batch size,
              N is the target sequence length, S is the source sequence length.
        """

        query = query.transpose(0, 1)  # (N,B,E)
        key = key.transpose(0, 1)  # (N,B,E)
        value = value.transpose(0, 1)  # (N,B,E)

        # The same attention mask is shared across all batches unless use_3d is set
        if attn_mask is not None and attn_mask.dim() > 2 and not use_3d:
            attn_mask = attn_mask[0]  # (N, S)

        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )
        else:
            return multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )
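
# Minimal usage sketch for MultiHeadAttention (the sizes below are illustrative
# assumptions, not values used elsewhere in this repo). Inputs are batch-first,
# while the returned attn_output is sequence-first, as documented in ``forward``:
#
#     mha = MultiHeadAttention(input_dim=256, embed_dim=256, num_heads=8)
#     x = torch.rand(4, 100, 256)               # (B, N, E)
#     attn_output, attn_weights = mha(x, x, x)  # (N, B, E), (B, N, S)
#
#     # A 3D additive mask of shape (B * num_heads, N, S) is only honoured when
#     # use_3d=True; otherwise forward() keeps just its first slice, i.e. the
#     # same (N, S) mask is broadcast over the whole batch.
#     mask = torch.zeros(4 * 8, 100, 100)
#     attn_output, attn_weights = mha(x, x, x, attn_mask=mask, use_3d=True)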

class CrossAttention(nn.Module):
    """Cross Attention layer for multi-modal feature sharing.

    Args:
        input_dim_list: list of input embedding dimensions. The first entry must be
            the dimension of the input used to generate the query.
        output_dim: output dimension (int).
        num_heads: number of parallel attention heads.
        query_name: name of the query modality, used in the returned attention-score key.
        context_1_name: name of the context modality, used in the returned attention-score key.
        dropout: dropout probability on the attention weights.
        block_level: block index, used in the returned attention-score key.
    """

    def __init__(
        self,
        input_dim_list,
        output_dim=512,
        num_heads=1,
        query_name="pc",
        context_1_name="cad",
        dropout=0,
        block_level=1,
    ):
        super(CrossAttention, self).__init__()
        self.query_name = query_name
        self.context_1_name = context_1_name
        self.block_level = block_level

        self.mha = MultiHeadAttention(
            input_dim=input_dim_list[0],
            embed_dim=output_dim,
            dropout=dropout,
            num_heads=num_heads,
        )

    def forward(self, X, Y, Z, attn_mask=None, key_padding_mask=None):
        """
        input:
        X: query input embedding of shape (B, N1, E)
        Y: context input embedding of shape (B, N2, E), used to generate the keys
        Z: context input embedding of shape (B, N2, E), used to generate the values
        attn_mask: optional additive attention mask
        key_padding_mask: optional binary mask marking padded key positions

        output:
        attention_output: attended embedding of shape (N1, B, output_dim)
        cross_attention_score: dict mapping "<query_name>_<context_1_name>_<block_level>"
            to the attention weights of shape (B, N1, N2)
        """
        attention_output, attention_scores = self.mha(
            X, Y, Z, attn_mask=attn_mask, key_padding_mask=key_padding_mask, use_3d=True
        )

        cross_attention_score = {
            f"{self.query_name}_{self.context_1_name}_{self.block_level}": attention_scores
        }

        return attention_output, cross_attention_score
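
# Minimal usage sketch for CrossAttention (hypothetical shapes; note that the
# feature size of X, Y and Z must equal ``output_dim``, since the underlying
# MultiHeadAttention defaults kdim/vdim to embed_dim and uses square projections):
#
#     cross_attn = CrossAttention(input_dim_list=[512, 512], output_dim=512,
#                                 num_heads=4, query_name="pc", context_1_name="cad")
#     pc_feat = torch.rand(2, 64, 512)    # query modality,   (B, N1, E)
#     cad_feat = torch.rand(2, 128, 512)  # context modality, (B, N2, E)
#     out, scores = cross_attn(pc_feat, cad_feat, cad_feat)
#     # out: (N1, B, 512); scores: {"pc_cad_1": tensor of shape (B, N1, N2)}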

